{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "4ee2125b-f889-47e6-9c3d-8bd63a253683",
   "metadata": {},
   "source": [
    "# Testing PySpark\n",
    "\n",
    "This guide is a reference for writing robust tests for PySpark code.\n",
    "\n",
    "To view the docs for PySpark test utils, see here. To see the code for PySpark built-in test utils, check out the Spark repository here. To see the JIRA board tickets for the PySpark test framework, see here."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0e8ee4b6-9544-45e1-8a91-e71ed8ef8b9d",
   "metadata": {},
   "source": [
    "## Build a PySpark Application\n",
    "Here is an example for how to start a PySpark application. Feel free to skip to the next section, “Testing your PySpark Application,” if you already have an application you’re ready to test.\n",
    "\n",
    "First, start your Spark Session."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "9af4a35b-17e8-4e45-816b-34c14c5902f7",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyspark.sql import SparkSession \n",
    "from pyspark.sql.functions import col \n",
    "\n",
    "# Create a SparkSession \n",
    "spark = SparkSession.builder.appName(\"Testing PySpark Example\").getOrCreate() "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4a4c6efe-91f5-4e18-b4b2-b0401c2368e4",
   "metadata": {},
   "source": [
    "Next, create a DataFrame."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "3b483dd8-3a76-41c6-9206-301d7ef314d6",
   "metadata": {},
   "outputs": [],
   "source": [
    "sample_data = [{\"name\": \"John    D.\", \"age\": 30}, \n",
    "  {\"name\": \"Alice   G.\", \"age\": 25}, \n",
    "  {\"name\": \"Bob  T.\", \"age\": 35}, \n",
    "  {\"name\": \"Eve   A.\", \"age\": 28}] \n",
    "\n",
    "df = spark.createDataFrame(sample_data)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e0f44333-0e08-470b-9fa2-38f59e3dbd63",
   "metadata": {},
   "source": [
    "Now, let’s define and apply a transformation function to our DataFrame."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "a6c0b766-af5f-4e1d-acf8-887d7cf0b0b2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+---+--------+\n",
      "|age|    name|\n",
      "+---+--------+\n",
      "| 30| John D.|\n",
      "| 25|Alice G.|\n",
      "| 35|  Bob T.|\n",
      "| 28|  Eve A.|\n",
      "+---+--------+\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from pyspark.sql.functions import col, regexp_replace\n",
    "\n",
    "# Remove additional spaces in name\n",
    "def remove_extra_spaces(df, column_name):\n",
    "    # Remove extra spaces from the specified column\n",
    "    df_transformed = df.withColumn(column_name, regexp_replace(col(column_name), \"\\\\s+\", \" \"))\n",
    "    \n",
    "    return df_transformed\n",
    "\n",
    "transformed_df = remove_extra_spaces(df, \"name\")\n",
    "\n",
    "transformed_df.show()"
   ]
  },
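  {
   "cell_type": "markdown",
   "id": "5f1c9a2e-7b3d-4e8a-9c61-2d4f8b0e3a71",
   "metadata": {},
   "source": [
    "The pattern `\\s+` matches a run of one or more whitespace characters, and `regexp_replace` replaces each run with a single space. As a rough mental model, the plain-Python sketch below applies the same substitution with `re.sub` to the sample names. It assumes Python’s `\\s` behaves like Spark’s Java-based regex engine for the ASCII spaces used here (the two engines can differ on other whitespace), and `collapse_spaces` is a hypothetical helper. Note that leading or trailing runs become a single space rather than being stripped.\n",
    "\n",
    "```python\n",
    "import re\n",
    "\n",
    "# Same pattern as the Spark transformation: one or more whitespace characters\n",
    "EXTRA_SPACES = re.compile(r\"\\s+\")\n",
    "\n",
    "def collapse_spaces(value: str) -> str:\n",
    "    # Replace every whitespace run with a single space (no trimming)\n",
    "    return EXTRA_SPACES.sub(\" \", value)\n",
    "\n",
    "names = [\"John    D.\", \"Alice   G.\", \"Bob  T.\", \"Eve   A.\"]\n",
    "print([collapse_spaces(n) for n in names])  # ['John D.', 'Alice G.', 'Bob T.', 'Eve A.']\n",
    "print(repr(collapse_spaces(\"  padded  \")))  # ' padded '\n",
    "```"
   ]
  },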
  {
   "cell_type": "markdown",
   "id": "530beaa6-aabf-43a1-ad2b-361f267e9608",
   "metadata": {},
   "source": [
    "## Testing your PySpark Application\n",
    "Now let’s test our PySpark transformation function. \n",
    "\n",
    "One option is to simply eyeball the resulting DataFrame. However, this can be impractical for large DataFrame or input sizes.\n",
    "\n",
    "A better way is to write tests. Here are some examples of how we can test our code. The examples below apply for Spark 3.5 and above versions.\n",
    "\n",
    "Note that these examples are not exhaustive, as there are many other test framework alternatives which you can use instead of `unittest` or `pytest`. The built-in PySpark testing util functions are standalone, meaning they can be compatible with any test framework or CI test pipeline.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d84a9fc1-9768-4af4-bfbf-e832f23334dc",
   "metadata": {},
   "source": [
    "### Option 1: Using Only PySpark Built-in Test Utility Functions\n",
    "\n",
    "For simple ad-hoc validation cases, PySpark testing utils like `assertDataFrameEqual` and `assertSchemaEqual` can be used in a standalone context.\n",
    "You could easily test PySpark code in a notebook session. For example, say you want to assert equality between two DataFrames:\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "8e533732-ee40-4cd0-9669-8eb92973908a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pyspark.testing\n",
    "from pyspark.testing.utils import assertDataFrameEqual\n",
    "\n",
    "# Example 1\n",
    "df1 = spark.createDataFrame(data=[(\"1\", 1000), (\"2\", 3000)], schema=[\"id\", \"amount\"])\n",
    "df2 = spark.createDataFrame(data=[(\"1\", 1000), (\"2\", 3000)], schema=[\"id\", \"amount\"])\n",
    "assertDataFrameEqual(df1, df2)  # pass, DataFrames are identical"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "2d77a6be-1e50-4c1a-8a44-85cf7dcec3f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Example 2\n",
    "df1 = spark.createDataFrame(data=[(\"1\", 0.1), (\"2\", 3.23)], schema=[\"id\", \"amount\"])\n",
    "df2 = spark.createDataFrame(data=[(\"1\", 0.109), (\"2\", 3.23)], schema=[\"id\", \"amount\"])\n",
    "assertDataFrameEqual(df1, df2, rtol=1e-1)  # pass, DataFrames are approx equal by rtol"
   ]
  },
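  {
   "cell_type": "markdown",
   "id": "8a3d6e4b-2c9f-4b17-a5e0-6f7c1d2b9e84",
   "metadata": {},
   "source": [
    "To build intuition for `rtol`, the plain-Python sketch below applies a NumPy-style closeness check, `|actual - expected| <= atol + rtol * |expected|`, to the values from Example 2. The formula and the defaults `rtol=1e-5` and `atol=1e-8` are assumptions modeled on PySpark 3.5’s `assertDataFrameEqual`, and `approx_equal` is a hypothetical helper, so check the API reference for your Spark version before relying on exact boundaries.\n",
    "\n",
    "```python\n",
    "def approx_equal(actual: float, expected: float, rtol: float = 1e-5, atol: float = 1e-8) -> bool:\n",
    "    # Hypothetical helper: absolute tolerance plus a tolerance scaled by the expected value\n",
    "    return abs(actual - expected) <= atol + rtol * abs(expected)\n",
    "\n",
    "# Example 2: 0.1 vs 0.109 differ by about 0.009, within rtol * |expected| = 0.1 * 0.109 = 0.0109\n",
    "print(approx_equal(0.1, 0.109, rtol=1e-1))  # True\n",
    "# With the default rtol, the same pair is not considered equal\n",
    "print(approx_equal(0.1, 0.109))  # False\n",
    "print(approx_equal(3.23, 3.23))  # True\n",
    "```"
   ]
  },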
  {
   "cell_type": "markdown",
   "id": "76ade5f2-4a1f-4601-9d2a-80da9da950ff",
   "metadata": {},
   "source": [
    "You can also simply compare two DataFrame schemas:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "74393af5-40fb-4d04-87cb-265971ffe6d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyspark.testing.utils import assertSchemaEqual\n",
    "from pyspark.sql.types import StructType, StructField, ArrayType, DoubleType\n",
    "\n",
    "s1 = StructType([StructField(\"names\", ArrayType(DoubleType(), True), True)])\n",
    "s2 = StructType([StructField(\"names\", ArrayType(DoubleType(), True), True)])\n",
    "\n",
    "assertSchemaEqual(s1, s2)  # pass, schemas are identical"
   ]
  },
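  {
   "cell_type": "markdown",
   "id": "c41e7b90-5d2a-4f3e-8b16-9a0d3e5f7c28",
   "metadata": {},
   "source": [
    "A useful check should also fail loudly on a real difference. The sketch below, assuming PySpark 3.5 or later where a failed assertion raises `PySparkAssertionError` from `pyspark.errors`, compares two schemas whose `amount` field has different types. Only `StructType` objects are compared, so it should not need a SparkSession.\n",
    "\n",
    "```python\n",
    "from pyspark.errors import PySparkAssertionError\n",
    "from pyspark.sql.types import DoubleType, StringType, StructField, StructType\n",
    "from pyspark.testing.utils import assertSchemaEqual\n",
    "\n",
    "# Same field name, different data types\n",
    "actual_schema = StructType([StructField(\"amount\", DoubleType(), True)])\n",
    "expected_schema = StructType([StructField(\"amount\", StringType(), True)])\n",
    "\n",
    "try:\n",
    "    assertSchemaEqual(actual_schema, expected_schema)\n",
    "    mismatch_detected = False\n",
    "except PySparkAssertionError as err:\n",
    "    # The error message describes how the two schemas differ\n",
    "    mismatch_detected = True\n",
    "    print(err)\n",
    "\n",
    "print(mismatch_detected)  # True\n",
    "```"
   ]
  },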
  {
   "cell_type": "markdown",
   "id": "c67be105-f6b1-4083-ad11-9e819331eae8",
   "metadata": {},
   "source": [
    "### Option 2: Using [Unit Test](https://docs.python.org/3/library/unittest.html)\n",
    "For more complex testing scenarios, you may want to use a testing framework.\n",
    "\n",
    "One of the most popular testing framework options is unit tests. Let’s walk through how you can use the built-in Python `unittest` library to write PySpark tests. For more information about the `unittest` library, see here: https://docs.python.org/3/library/unittest.html.  \n",
    "\n",
    "First, you will need a Spark session. You can use the `@classmethod` decorator from the `unittest` package to take care of setting up and tearing down a Spark session."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "54093761-0b49-4aee-baec-2d29bcf13f9f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import unittest\n",
    "\n",
    "class PySparkTestCase(unittest.TestCase):\n",
    "    @classmethod\n",
    "    def setUpClass(cls):\n",
    "        cls.spark = SparkSession.builder.appName(\"Testing PySpark Example\").getOrCreate() \n",
    "\n",
    "    \n",
    "    @classmethod\n",
    "    def tearDownClass(cls):\n",
    "        cls.spark.stop()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3de27500-8526-412e-bf09-6927a760c5d7",
   "metadata": {},
   "source": [
    "Now let’s write a `unittest` class."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "34feb5e1-944f-4f6b-9c5f-3b0bf68c7d05",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyspark.testing.utils import assertDataFrameEqual\n",
    "\n",
    "class TestTranformation(PySparkTestCase):\n",
    "    def test_single_space(self):\n",
    "        sample_data = [{\"name\": \"John    D.\", \"age\": 30}, \n",
    "                       {\"name\": \"Alice   G.\", \"age\": 25}, \n",
    "                       {\"name\": \"Bob  T.\", \"age\": 35}, \n",
    "                       {\"name\": \"Eve   A.\", \"age\": 28}] \n",
    "                        \n",
    "        # Create a Spark DataFrame\n",
    "        original_df = spark.createDataFrame(sample_data)\n",
    "        \n",
    "        # Apply the transformation function from before\n",
    "        transformed_df = remove_extra_spaces(original_df, \"name\")\n",
    "        \n",
    "        expected_data = [{\"name\": \"John D.\", \"age\": 30}, \n",
    "        {\"name\": \"Alice G.\", \"age\": 25}, \n",
    "        {\"name\": \"Bob T.\", \"age\": 35}, \n",
    "        {\"name\": \"Eve A.\", \"age\": 28}]\n",
    "        \n",
    "        expected_df = spark.createDataFrame(expected_data)\n",
    "    \n",
    "        assertDataFrameEqual(transformed_df, expected_df)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "319a690f-71bd-4886-bd3a-424e866525c2",
   "metadata": {},
   "source": [
    "When run, `unittest` will pick up all functions with a name beginning with “test.”"
   ]
  },
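  {
   "cell_type": "markdown",
   "id": "e27b5c13-9f4d-4a86-b0c2-3d8e6a1f5b97",
   "metadata": {},
   "source": [
    "To check which methods `unittest` will collect, `unittest.TestLoader.getTestCaseNames` returns the method names that start with the loader’s `testMethodPrefix` (`\"test\"` by default). The class below is a hypothetical, Spark-free stand-in used only to illustrate the naming rule; the sketch also runs the collected tests programmatically with `TextTestRunner`.\n",
    "\n",
    "```python\n",
    "import unittest\n",
    "\n",
    "class ExampleNamingCase(unittest.TestCase):\n",
    "    # Hypothetical test case: only names starting with \"test\" are collected\n",
    "    def test_single_space(self):\n",
    "        self.assertEqual(\" \".join(\"John    D.\".split()), \"John D.\")\n",
    "\n",
    "    def helper_not_collected(self):\n",
    "        pass\n",
    "\n",
    "loader = unittest.TestLoader()\n",
    "print(list(loader.getTestCaseNames(ExampleNamingCase)))  # ['test_single_space']\n",
    "\n",
    "# Run the collected tests programmatically and inspect the result\n",
    "suite = loader.loadTestsFromTestCase(ExampleNamingCase)\n",
    "result = unittest.TextTestRunner(verbosity=0).run(suite)\n",
    "print(result.wasSuccessful())  # True\n",
    "```"
   ]
  },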
  {
   "cell_type": "markdown",
   "id": "7d79e53d-cc1e-4fdf-a069-478337bed83d",
   "metadata": {},
   "source": [
    "### Option 3: Using [Pytest](https://docs.pytest.org/en/7.1.x/contents.html)\n",
    "\n",
    "We can also write our tests with `pytest`, which is one of the most popular Python testing frameworks. For more information about `pytest`, see the docs here: https://docs.pytest.org/en/7.1.x/contents.html.\n",
    "\n",
    "Using a `pytest` fixture allows us to share a spark session across tests, tearing it down when the tests are complete."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "60a4f304-1911-4b4d-8ed9-00ecc8b0890b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pytest\n",
    "\n",
    "@pytest.fixture\n",
    "def spark_fixture():\n",
    "    spark = SparkSession.builder.appName(\"Testing PySpark Example\").getOrCreate()\n",
    "    yield spark"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fcb4e26a-9bfc-48a5-8aca-538697d66642",
   "metadata": {},
   "source": [
    "We can then define our tests like this:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "fa5db3a1-7305-44b7-ab84-f5ed55fd2ba9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pytest\n",
    "from pyspark.testing.utils import assertDataFrameEqual\n",
    "\n",
    "def test_single_space(spark_fixture):\n",
    "    sample_data = [{\"name\": \"John    D.\", \"age\": 30}, \n",
    "                   {\"name\": \"Alice   G.\", \"age\": 25}, \n",
    "                   {\"name\": \"Bob  T.\", \"age\": 35}, \n",
    "                   {\"name\": \"Eve   A.\", \"age\": 28}] \n",
    "                    \n",
    "    # Create a Spark DataFrame\n",
    "    original_df = spark.createDataFrame(sample_data)\n",
    "    \n",
    "    # Apply the transformation function from before\n",
    "    transformed_df = remove_extra_spaces(original_df, \"name\")\n",
    "    \n",
    "    expected_data = [{\"name\": \"John D.\", \"age\": 30}, \n",
    "    {\"name\": \"Alice G.\", \"age\": 25}, \n",
    "    {\"name\": \"Bob T.\", \"age\": 35}, \n",
    "    {\"name\": \"Eve A.\", \"age\": 28}]\n",
    "    \n",
    "    expected_df = spark.createDataFrame(expected_data)\n",
    "\n",
    "    assertDataFrameEqual(transformed_df, expected_df)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0fc3f394-3260-4e42-82cf-1a7edc859151",
   "metadata": {},
   "source": [
    "When you run your test file with the `pytest` command, it will pick up all functions that have their name beginning with “test.”"
   ]
  },
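  {
   "cell_type": "markdown",
   "id": "9b6f2d48-1e7a-4c35-8d90-4a5c7e3b1f06",
   "metadata": {},
   "source": [
    "pytest’s default collection rules come from its `python_files` (`test_*.py`, `*_test.py`) and `python_functions` (`test` prefix) settings, which can be changed in a config file. The sketch below approximates only those defaults with `fnmatch`; `is_test_file` and `is_test_function` are hypothetical helpers, and the real collector applies more rules (for example, for `Test`-prefixed classes).\n",
    "\n",
    "```python\n",
    "from fnmatch import fnmatch\n",
    "\n",
    "# pytest's documented defaults (configurable via python_files / python_functions)\n",
    "DEFAULT_FILE_PATTERNS = (\"test_*.py\", \"*_test.py\")\n",
    "DEFAULT_FUNCTION_PREFIX = \"test\"\n",
    "\n",
    "def is_test_file(filename: str) -> bool:\n",
    "    # Hypothetical helper: does the file name match one of the default glob patterns?\n",
    "    return any(fnmatch(filename, pattern) for pattern in DEFAULT_FILE_PATTERNS)\n",
    "\n",
    "def is_test_function(name: str) -> bool:\n",
    "    # Hypothetical helper: the default python_functions rule is a name prefix\n",
    "    return name.startswith(DEFAULT_FUNCTION_PREFIX)\n",
    "\n",
    "print([f for f in [\"etl.py\", \"test_etl.py\", \"etl_test.py\"] if is_test_file(f)])  # ['test_etl.py', 'etl_test.py']\n",
    "print([n for n in [\"test_single_space\", \"spark_fixture\"] if is_test_function(n)])  # ['test_single_space']\n",
    "```"
   ]
  },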
  {
   "cell_type": "markdown",
   "id": "d8f50eee-5d0b-4719-b505-1b3ff05c16e8",
   "metadata": {},
   "source": [
    "## Putting It All Together!\n",
    "\n",
    "Let’s see all the steps together, in a Unit Test example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "a2ea9dec-0ac0-4c23-8770-d6cc226d2e97",
   "metadata": {},
   "outputs": [],
   "source": [
    "# pkg/etl.py\n",
    "import unittest\n",
    "\n",
    "from pyspark.sql import SparkSession \n",
    "from pyspark.sql.functions import col\n",
    "from pyspark.sql.functions import regexp_replace\n",
    "from pyspark.testing.utils import assertDataFrameEqual\n",
    "\n",
    "# Create a SparkSession \n",
    "spark = SparkSession.builder.appName(\"Sample PySpark ETL\").getOrCreate() \n",
    "\n",
    "sample_data = [{\"name\": \"John    D.\", \"age\": 30}, \n",
    "  {\"name\": \"Alice   G.\", \"age\": 25}, \n",
    "  {\"name\": \"Bob  T.\", \"age\": 35}, \n",
    "  {\"name\": \"Eve   A.\", \"age\": 28}] \n",
    "\n",
    "df = spark.createDataFrame(sample_data)\n",
    "\n",
    "# Define DataFrame transformation function\n",
    "def remove_extra_spaces(df, column_name):\n",
    "    # Remove extra spaces from the specified column using regexp_replace\n",
    "    df_transformed = df.withColumn(column_name, regexp_replace(col(column_name), \"\\\\s+\", \" \"))\n",
    "\n",
    "    return df_transformed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "248aede2-feb9-4828-bd9c-8e25e6b194ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "# pkg/test_etl.py\n",
    "import unittest\n",
    "\n",
    "from pyspark.sql import SparkSession \n",
    "\n",
    "# Define unit test base class\n",
    "class PySparkTestCase(unittest.TestCase):\n",
    "    @classmethod\n",
    "    def setUpClass(cls):\n",
    "        cls.spark = SparkSession.builder.appName(\"Sample PySpark ETL\").getOrCreate() \n",
    "\n",
    "    @classmethod\n",
    "    def tearDownClass(cls):\n",
    "        cls.spark.stop()\n",
    "        \n",
    "# Define unit test\n",
    "class TestTranformation(PySparkTestCase):\n",
    "    def test_single_space(self):\n",
    "        sample_data = [{\"name\": \"John    D.\", \"age\": 30}, \n",
    "                        {\"name\": \"Alice   G.\", \"age\": 25}, \n",
    "                        {\"name\": \"Bob  T.\", \"age\": 35}, \n",
    "                        {\"name\": \"Eve   A.\", \"age\": 28}] \n",
    "                \n",
    "        # Create a Spark DataFrame\n",
    "        original_df = spark.createDataFrame(sample_data)\n",
    "    \n",
    "        # Apply the transformation function from before\n",
    "        transformed_df = remove_extra_spaces(original_df, \"name\")\n",
    "    \n",
    "        expected_data = [{\"name\": \"John D.\", \"age\": 30}, \n",
    "        {\"name\": \"Alice G.\", \"age\": 25}, \n",
    "        {\"name\": \"Bob T.\", \"age\": 35}, \n",
    "        {\"name\": \"Eve A.\", \"age\": 28}]\n",
    "    \n",
    "        expected_df = spark.createDataFrame(expected_data)\n",
    "    \n",
    "        assertDataFrameEqual(transformed_df, expected_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "a77df5b2-f32e-4d8c-a64b-0078dfa21217",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Ran 1 test in 1.734s\n",
      "\n",
      "OK\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<unittest.main.TestProgram at 0x174539db0>"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "unittest.main(argv=[''], verbosity=0, exit=False)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "jupyter-oss-env",
   "language": "python",
   "name": "jupyter-oss-env"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
